import os
import re
import pprint

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import (accuracy_score, ConfusionMatrixDisplay, classification_report, recall_score,
                             precision_score, f1_score)


# Display name of each evaluation task, keyed by task number.
task_num_to_task_name = {1: '1: Hate Speech'}

# Display name of each fine-tuning dataset; the key is the 1-based position in this list.
dataset_num_to_dataset_name = dict(enumerate(
    [
        '1: Finetuning with Experts',
        '2: Finetuning with NGO',
        '3: Finetuning with Appen',
        '4: Finetuning with Citizen Science',
        '5: Finetuning with Prolific',
        '6: Finetuning with Research Assistants',
        '7: Finetuning with GPT Zero Shot',
    ],
    start=1,
))

# Class labels per task: 'full_name' holds the label strings and 'short_name'
# the answer letter at the same position.
task_to_display_labels = {
    1: dict(
        full_name=['HATE SPEECH', 'KEINE HATE SPEECH', 'TOXIC SPEECH'],
        short_name=['A', 'B', 'C'],
    ),
}

# Metrics computed by default after plotting: each entry maps a metric name to a
# callable taking (y_true, y_pred). Recall, precision and F1 are micro-averaged.
default_metrics = dict(
    accuracy=accuracy_score,
    recall=lambda y_t, y_p: recall_score(y_t, y_p, average='micro', zero_division='warn'),
    precision=lambda y_t, y_p: precision_score(y_t, y_p, average='micro', zero_division='warn'),
    f1=lambda y_t, y_p: f1_score(y_t, y_p, average='micro', zero_division='warn'),
)


def plot_count_and_normalized_confusion_matrix(y_true, y_pred, display_labels, labels, xticks_rotation='horizontal',
                                               metrics: dict = default_metrics):
    """Print a classification report and plot count and row-normalized confusion matrices side by side.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        display_labels: Tick labels to show, one per entry of ``labels`` (same order).
        labels: Class labels to include in the matrices. Labels that do not occur in ``y_true`` are dropped,
            together with their display label.
        xticks_rotation: Rotation of the x tick labels, forwarded to scikit-learn.
        metrics: Mapping of metric name to a callable ``f(y_true, y_pred)``. The mapping is only read.

    Returns:
        Tuple ``(fig, cls_report, metric_values)``: the matplotlib figure (already shown and closed), the
        classification report as a dict, and a dict of metric name to computed value.
    """
    # Print classification report
    cls_report = classification_report(y_true, y_pred, output_dict=True)
    pprint.pprint(cls_report)

    # Keep only the classes present in y_true. Labels and display labels are filtered as pairs so they stay
    # aligned even when the display text differs from the label value.
    present = set(pd.Series(y_true).unique())
    kept = [(label, shown) for label, shown in zip(labels, display_labels) if label in present]
    labels = [label for label, _ in kept]
    display_labels = [shown for _, shown in kept]

    # Create plot with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Draw straight onto the prepared axes; calling from_predictions without `ax` would open an extra figure
    # for every matrix.
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, labels=labels, display_labels=display_labels,
                                            xticks_rotation=xticks_rotation, ax=ax1)
    ax1.set_title('Count Confusion Matrix')

    # normalize='true' normalizes each row, i.e. over the ground-truth class
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, labels=labels, display_labels=display_labels,
                                            normalize='true', xticks_rotation=xticks_rotation, ax=ax2)
    ax2.set_title('Normalized Confusion Matrix')

    # Show plot, then release exactly the figure created here
    plt.show()
    plt.close(fig)

    # Calculate metrics
    metric_values = {metric_name: metric_func(y_true, y_pred) for metric_name, metric_func in metrics.items()}

    return fig, cls_report, metric_values


def map_outputs_task_1(output):
    """Map a raw model answer for the hate-speech task to one of the class labels.

    The answer is matched case-insensitively against the option letter (A/B/C, optionally prefixed with
    ``answer:`` or followed by ``.``, ``:`` or ``)``) or against the label text itself.

    Args:
        output: Raw model output. Usually a string; ``None`` and float NaN are treated as missing.

    Returns:
        ``'HATE SPEECH'``, ``'KEINE HATE SPEECH'`` or ``'TOXIC SPEECH'``, or ``np.nan`` when the output is
        missing (including the string ``'nan'``) or cannot be interpreted. Uninterpretable outputs are printed.
    """
    # NaN never compares equal to itself, so missing values need an explicit check.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return np.nan

    text = str(output).lower().strip()
    if text == 'nan':
        return np.nan

    # "keine hate speech" / "no hate speech" contain the phrase "hate speech"; the bare phrase may only count
    # as HATE SPEECH when it is not negated.
    negated = re.search(r'keine hate speech|no hate speech', text) is not None

    if re.search(r'^(answer:){0,1}(\s)*a(\s)*$|(a(\.|:|\)))', text) or (
            not negated and re.search(r'(\s|^|\')hate speech', text)):
        return 'HATE SPEECH'
    if re.search(r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|keine hate speech|no hate speech|\s+b$', text):
        return 'KEINE HATE SPEECH'
    # Option C is the toxic-speech answer.
    if re.search(r'^(answer:){0,1}(\s)*c(\s)*$|c(\.|:|\))|toxic speech|\s+c$', text):
        return 'TOXIC SPEECH'

    print(f'Weird value: {text}')
    return np.nan



def get_accuracy_accross_tasks(output_dir):
    """Evaluate every model-output CSV in ``output_dir`` and collect one accuracy row per file.

    File names are split on ``'__'``: the trailing ``_``-separated token of the first part is the dataset
    number, that of the second part the task number, and the third part is used as the model name.

    For each file the raw ``model_output`` column is mapped to class labels, the frame including the mapped
    column is written to ``<output_dir>/mapped_outputs/``, summary numbers are printed and the confusion
    matrices are plotted. Files containing at least one ``'OOM'`` output have those rows written to
    ``<output_dir>/OOM/`` and are skipped. A file whose evaluation raises is reported on stdout and skipped.

    Args:
        output_dir: Directory containing the result CSV files.

    Returns:
        DataFrame indexed by ``(task, dataset, model)`` with the columns ``accuracy``,
        ``percentage_of_rows_used`` and ``percentage_of_rows_instruction_followed`` (the latter two are
        fractions in [0, 1]). When no file could be evaluated the frame is empty but keeps that layout.
    """
    index_columns = ['task', 'dataset', 'model']
    value_columns = ['accuracy', 'percentage_of_rows_used', 'percentage_of_rows_instruction_followed']

    results_fps = sorted(fp for fp in os.listdir(output_dir)
                         if fp.endswith('csv') and not fp.endswith('_with_mapped_outputs.csv'))

    results = []
    for results_fp in results_fps:
        df = pd.read_csv(os.path.join(output_dir, results_fp))
        name_parts = results_fp.split('__')
        dataset_id = name_parts[0].split('_')[-1]
        dataset_num = int(dataset_id)
        task_num = int(name_parts[1].split('_')[-1])
        model_name = name_parts[2]

        # Filter out any rows without status_id (i.e: not valid rows); dataset 4 uses 'id' instead.
        # Copy so that adding the mapped-output column below does not write into a slice of the original frame.
        id_column = 'status_id' if dataset_num != 4 else 'id'
        df = df[df[id_column].notna()].copy()
        total_rows = df.shape[0]

        # Check if any outputs from the model reached OOM
        oom_df = df[df['model_output'].astype(str) == 'OOM']
        if oom_df.shape[0] > 0:
            os.makedirs(os.path.join(output_dir, 'OOM'), exist_ok=True)
            print('#' * 50)
            print(f'Model {model_name} on dataset {dataset_num}, task {task_num} reached OOM on {oom_df.shape[0]} rows')
            print('#' * 50)
            oom_df.to_csv(os.path.join(output_dir, 'OOM', results_fp), index=False)

            if oom_df.shape[0] == df.shape[0]:
                print('All rows were OOM')
            else:
                print('Around {:.2f}% of rows were OOM'.format(oom_df.shape[0] / df.shape[0] * 100))

            print('skipping this dataset and task\n\n' + '#' * 50 + '\n\n')
            continue

        x_tick_rotation = 'horizontal'

        try:
            if task_num == 1:
                pred_column = f'hatespeech_flan_{model_name}'
                df[pred_column] = df['model_output'].astype(str).map(map_outputs_task_1)
                # Select with the boolean mask itself: unlabeled rows must be excluded here, otherwise their
                # NaN would fall through to the 'KEINE HATE SPEECH' default of the mapping below.
                labeled = df['label_column'].notna()
                y_true = df.loc[labeled, 'label_column']\
                    .map(lambda num: 'HATE SPEECH' if num == 1 else 'TOXIC SPEECH' if num == 2 else 'KEINE HATE SPEECH')
                y_pred = df.loc[y_true.index, pred_column]
                display_labels = ['HATE SPEECH', 'TOXIC SPEECH', 'KEINE HATE SPEECH']
                labels = display_labels

            elif task_num == 2:
                # NOTE(review): map_outputs_task_2 is not defined in the visible part of this file - confirm it
                # exists; otherwise every task-2 file ends up in the except branch below.
                pred_column = f'problem_solution_{model_name}'
                df[pred_column] = df['model_output'].astype(str).map(map_outputs_task_2)
                y_true = df['problem_solution_ra']
                y_pred = df[pred_column]
                display_labels = ['Problem', 'Solution', 'Neither']
                labels = display_labels

            else:
                raise ValueError(f'Unknown task number: {task_num}')

            # Save the results with the mapped model output
            os.makedirs(os.path.join(output_dir, 'mapped_outputs'), exist_ok=True)
            assert df.shape[0] == total_rows
            df.to_csv(os.path.join(output_dir, 'mapped_outputs', results_fp), index=False)

            # Choose valid indexes to use for evaluation (non-inplace: y_true may be a column of df)
            y_true = y_true.dropna()
            total_labeled_rows = y_true.shape[0]
            if total_labeled_rows == 0:
                raise ValueError('no rows with a ground-truth label')

            y_pred = y_pred.loc[y_true.index]
            num_rows_instruction_followed = int((y_pred.notna() & (y_pred != 'nan')).sum())
            # Unmapped predictions become their own 'NA' class so they count as errors
            y_pred = y_pred.fillna('NA')

            accuracy = accuracy_score(y_true, y_pred)
            labeled_fraction = total_labeled_rows / total_rows
            followed_fraction = num_rows_instruction_followed / total_labeled_rows

            print('====================')
            print(results_fp)
            print('Dataset:', dataset_id)
            print('Model:', model_name)
            print('Task:', task_num)
            print('Accuracy:', accuracy)
            print('Percentage of ground truth labeled rows: ', labeled_fraction)
            print('Percentage of rows instruction correctly followed:', followed_fraction)

            plot_count_and_normalized_confusion_matrix(y_true, y_pred, display_labels, labels, x_tick_rotation)

            print('====================')
            print('\n\n')

            results.append(
                {
                    'dataset': dataset_id,
                    'model': model_name,
                    'task': task_num,
                    'accuracy': accuracy,
                    'percentage_of_rows_used': labeled_fraction,
                    'percentage_of_rows_instruction_followed': followed_fraction
                }
            )

        except Exception as e:
            # Best effort: one broken file must not abort the evaluation of the remaining files.
            print('Problem with:', results_fp)
            print('e:', e)
            print('\n\n')

    # Passing the columns explicitly keeps set_index working when nothing was evaluated.
    results_df = pd.DataFrame(results, columns=index_columns + value_columns).set_index(index_columns)
    if not results_df.empty:
        results_df = results_df.sort_index()

    return results_df
